# pylint:disable=missing-class-docstring
"""
All type constants used in type inference. They can be mapped, translated, or rewritten to C-style types.
"""
from __future__ import annotations

import functools
import itertools


def memoize(f):
    """
    Decorator that guards a recursive ``__repr__``-style method against cycles.

    The wrapped method receives a ``memo`` keyword argument holding the objects whose representation is currently
    being built. When an object is reached again while it is still in ``memo``, the string ``"..."`` is returned
    instead of recursing.

    :param f:   The method to wrap. It must accept a ``memo`` keyword argument.
    :return:    The wrapped method.
    """

    @functools.wraps(f)
    def wrapped_repr(self, *args, **kwargs):
        # Treat a missing memo and an explicit ``memo=None`` the same way: start a fresh set.
        memo = kwargs.pop("memo", None)
        if memo is None:
            memo = set()
        if self in memo:
            return "..."
        memo.add(self)
        try:
            return f(self, *args, memo=memo, **kwargs)
        finally:
            # Always undo the bookkeeping, so that a memo shared with the caller stays clean even if f raises.
            memo.discard(self)

    return wrapped_repr


class TypeConstant:
    """
    Base class of all type constants.

    The base implementation of equality and hashing only looks at the concrete class of the object.
    """

    # Size of the type; None means this class does not define one (the fixed-width classes in this file use bytes).
    SIZE = None

    def __init__(self, name: str | None = None):
        self.name = name

    def __repr__(self, memo=None) -> str:
        # Concrete type constants are expected to provide their own representation.
        raise NotImplementedError

    def pp_str(self, mapping) -> str:  # pylint:disable=unused-argument
        """
        Return a printable string for this type. The base implementation ignores ``mapping`` and returns
        ``repr(self)``.
        """
        return repr(self)

    @property
    def size(self) -> int:
        """
        The class-level ``SIZE``.

        :raises NotImplementedError: If ``SIZE`` is None.
        """
        declared = self.SIZE
        if declared is not None:
            return declared
        raise NotImplementedError

    def __eq__(self, other):
        return type(other) is type(self)

    def _hash(self, visited: set[int]):  # pylint:disable=unused-argument
        """
        Hash helper that accepts the set of ids of objects already visited. The base implementation does not use
        ``visited`` and hashes the concrete class only.
        """
        return hash(type(self))

    def __hash__(self):
        return self._hash(set())


class TopType(TypeConstant):
    """
    The type constant printed as ``TOP`` (presumably the top element of the type lattice).
    """

    def __repr__(self, memo=None) -> str:
        return "TOP"


class BottomType(TypeConstant):
    """
    The type constant printed as ``BOT`` (presumably the bottom element of the type lattice).
    """

    def __repr__(self, memo=None) -> str:
        return "BOT"


class Int(TypeConstant):
    """
    Base class of the integer type constants. It declares no size of its own.
    """

    def __repr__(self, memo=None) -> str:
        return "intbase"


class Int1(Int):
    """
    1-bit integers. ``SIZE`` is 1, the same value that ``Int8`` declares.
    """

    SIZE = 1

    def __repr__(self, memo=None) -> str:
        # Without this override the inherited representation is "intbase", which is indistinguishable from Int.
        return "int1"


class Int8(Int):
    """
    8-bit integers (one byte).
    """

    SIZE = 1

    def __repr__(self, memo=None) -> str:
        return "int8"


class Int16(Int):
    """
    16-bit integers (two bytes).
    """

    SIZE = 2

    def __repr__(self, memo=None) -> str:
        return "int16"


class Int32(Int):
    """
    32-bit integers (four bytes).
    """

    SIZE = 4

    def __repr__(self, memo=None) -> str:
        return "int32"


class Int64(Int):
    """
    64-bit integers (eight bytes).
    """

    SIZE = 8

    def __repr__(self, memo=None) -> str:
        return "int64"


class Int128(Int):
    """
    128-bit integers (16 bytes).
    """

    SIZE = 16

    def __repr__(self, memo=None) -> str:
        return "int128"


class Int256(Int):
    """
    256-bit integers (32 bytes).
    """

    SIZE = 32

    def __repr__(self, memo=None) -> str:
        return "int256"


class Int512(Int):
    """
    512-bit integers (64 bytes).
    """

    # 512 bits / 8 = 64 bytes, in line with the other fixed-width integer classes (e.g. Int256 -> 32).
    SIZE = 64

    def __repr__(self, memo=None) -> str:
        return "int512"


class IntVar(Int):
    """
    An integer type whose size is supplied per instance instead of through the class-level ``SIZE``.
    """

    def __init__(self, size, name: str | None = None):
        super().__init__(name=name)
        self._size = size

    def __repr__(self, memo=None) -> str:
        return "intvar"

    @property
    def size(self) -> int:
        """
        The size that was passed to the constructor.
        """
        return self._size


class Float(TypeConstant):
    """
    Base class of the floating-point type constants. It declares no size of its own.
    """

    def __repr__(self, memo=None) -> str:
        return "floatbase"


class Float32(Float):
    """
    32-bit floating-point numbers (four bytes).
    """

    SIZE = 4

    def __repr__(self, memo=None) -> str:
        return "float32"


class Float64(Float):
    """
    64-bit floating-point numbers (eight bytes).
    """

    SIZE = 8

    def __repr__(self, memo=None) -> str:
        return "float64"


class Pointer(TypeConstant):
    """
    Base class of pointer types. ``basetype`` holds the pointed-to type, or None.
    """

    def __init__(self, basetype: TypeConstant | None, name: str | None = None):
        super().__init__(name=name)
        self.basetype: TypeConstant | None = basetype

    def new(self, basetype):
        """
        Create a pointer of the same concrete class as ``self`` around ``basetype``. The name is not carried over.
        """
        return self.__class__(basetype)

    def __eq__(self, other):
        if type(self) is not type(other):
            return False
        return self.basetype == other.basetype

    def _hash(self, visited: set[int]):
        base = self.basetype
        if base is not None:
            # Pass ``visited`` along so that the base type can detect cycles.
            return hash((type(self), base._hash(visited)))
        return hash(type(self))

    def __hash__(self):
        return self._hash(set())


class Pointer32(Pointer, Int32):
    """
    32-bit pointers.
    """

    def __init__(self, basetype=None, name: str | None = None):
        # Both base-class initializers are invoked explicitly.
        Pointer.__init__(self, basetype, name=name)
        Int32.__init__(self, name=name)

    @memoize
    def __repr__(self, memo=None):
        base = self.basetype
        if isinstance(base, TypeConstant):
            # Forward the memo so that cyclic types are cut off with "...".
            inner = base.__repr__(memo=memo)
        else:
            inner = repr(base)
        label = f"{self.name}#" if self.name else ""
        return f"{label}ptr32({inner})"


class Pointer64(Pointer, Int64):
    """
    64-bit pointers.
    """

    def __init__(self, basetype=None, name: str | None = None):
        # Both base-class initializers are invoked explicitly.
        Pointer.__init__(self, basetype, name=name)
        Int64.__init__(self, name=name)

    @memoize
    def __repr__(self, memo=None):
        base = self.basetype
        if isinstance(base, TypeConstant):
            # Forward the memo so that cyclic types are cut off with "...".
            inner = base.__repr__(memo=memo)
        else:
            inner = repr(base)
        label = f"{self.name}#" if self.name else ""
        return f"{label}ptr64({inner})"


class Array(TypeConstant):
    """
    An array type made of ``count`` elements of type ``element``. Either may be None when unknown.
    """

    def __init__(self, element=None, count=None, name: str | None = None):
        super().__init__(name=name)
        self.element: TypeConstant | None = element
        self.count: int | None = count

    @property
    def size(self) -> int:
        """
        The element size multiplied by the element count, or 0 if either the count or the element is missing.
        """
        if not self.count or not self.element:
            return 0
        return self.element.size * self.count

    @memoize
    def __repr__(self, memo=None):
        # Forward the memo to the element. Formatting it with a plain repr() would start a fresh memo and lose
        # cycle detection for self-referential types.
        if isinstance(self.element, TypeConstant):
            element_str = self.element.__repr__(memo=memo)
        else:
            element_str = repr(self.element)
        count_str = "?" if self.count is None else self.count
        return f"{element_str}[{count_str}]"

    def __eq__(self, other):
        return type(other) is type(self) and self.element == other.element and self.count == other.count

    def _hash(self, visited: set[int]):
        if id(self) in visited:
            return 0
        visited.add(id(self))
        # Hash the element through _hash() so that ``visited`` is shared. hash(self.element) would restart with an
        # empty set and recurse without bound on cyclic types.
        if isinstance(self.element, TypeConstant):
            element_key = self.element._hash(visited)
        else:
            element_key = self.element
        return hash((type(self), element_key, self.count))

    def __hash__(self):
        return self._hash(set())


# Module-wide counter; Struct draws its ``idx`` from here when no explicit index is passed to its constructor.
_STRUCT_ID = itertools.count()


class Struct(TypeConstant):
    """
    A structure type whose ``fields`` map offsets to field types.

    A struct created without an explicit ``idx`` takes the next value of the module-level ``_STRUCT_ID`` counter.
    """

    def __init__(self, fields=None, name=None, field_names=None, is_cppclass: bool = False, idx: int = -1):
        super().__init__(name=name)
        if fields is None:
            fields = {}
        self.fields = fields  # offset to type
        self.field_names = field_names
        self.is_cppclass = is_cppclass
        # -1 is the sentinel for "allocate the next free index"
        if idx != -1:
            self.idx = idx
        else:
            self.idx = next(_STRUCT_ID)

    @property
    def size(self) -> int:
        """
        The highest field offset plus the size of the field stored there (a BottomType field counts as 1), or 0
        for a struct without fields.
        """
        if not self.fields:
            return 0
        last_offset = max(self.fields.keys())
        last_field = self.fields[last_offset]
        if isinstance(last_field, BottomType):
            return last_offset + 1
        return last_offset + last_field.size

    @memoize
    def __repr__(self, memo=None):
        kind = "CppClass" if self.is_cppclass else "struct"
        header = f"{kind}#{self.idx}"
        if self.name:
            header = f"{header} {self.name}"
        # The memo is forwarded to every field so that self-referential structs are cut off with "...".
        body = ", ".join(
            f"{offset}:{'None' if field is None else field.__repr__(memo=memo)}"
            for offset, field in self.fields.items()
        )
        return header + "{" + body + "}"

    def __eq__(self, other):
        if type(other) is not type(self):
            return False
        if not self.idx == other.idx:
            return False
        # Field contents are compared through their hash.
        return hash(self) == hash(other)

    def _hash(self, visited: set[int]):
        # A struct that is already being hashed contributes 0, which breaks cycles.
        if id(self) in visited:
            return 0
        visited.add(id(self))
        return hash((type(self), self.idx, self._hash_fields(visited)))

    def _hash_fields(self, visited: set[int]):
        """
        Hash the (offset, field hash) pairs in ascending offset order; a None field contributes None.
        """
        pairs = []
        for offset in sorted(self.fields.keys()):
            field = self.fields[offset]
            pairs.append((offset, None if field is None else field._hash(visited)))
        return hash(tuple(pairs))

    def __hash__(self):
        return self._hash(set())


class Function(TypeConstant):
    """
    A function type described by the list of its parameter types and the list of its output types.
    """

    def __init__(self, params: list, outputs: list, name: str | None = None):
        super().__init__(name=name)
        self.params = params
        self.outputs = outputs

    @memoize
    def __repr__(self, memo=None):
        def _fmt(item) -> str:
            # Forward the memo to nested type constants. A plain repr() would start a fresh memo, so a function
            # type that refers back to itself would recurse without bound.
            if isinstance(item, TypeConstant):
                return item.__repr__(memo=memo)
            return repr(item)

        param_str = ", ".join(_fmt(param) for param in self.params)
        outputs_str = ", ".join(_fmt(output) for output in self.outputs)
        return f"func({param_str}) -> {outputs_str}"

    def __eq__(self, other):
        if not isinstance(other, Function):
            return False
        return self.params == other.params and self.outputs == other.outputs

    def _hash(self, visited: set[int]):
        # A function that is already being hashed contributes 0, which breaks cycles.
        if id(self) in visited:
            return 0
        visited.add(id(self))

        params_hash = tuple(param._hash(visited) for param in self.params)
        outputs_hash = tuple(out._hash(visited) for out in self.outputs)
        return hash((Function, params_hash, outputs_hash))

    def __hash__(self):
        return self._hash(set())


class TypeVariableReference(TypeConstant):
    """
    A type constant that stands for a reference to the type variable ``typevar``.
    """

    def __init__(self, typevar, name: str | None = None):
        super().__init__(name=name)
        self.typevar = typevar

    def __repr__(self, memo=None) -> str:
        return f"ref({self.typevar})"

    def __eq__(self, other):
        return type(other) is type(self) and self.typevar == other.typevar

    def _hash(self, visited: set[int]):  # pylint:disable=unused-argument
        # Include the referenced type variable. The inherited _hash() only hashes the class, so containers that
        # hash their members through _hash() could not tell references to different type variables apart.
        return hash((type(self), self.typevar))

    def __hash__(self):
        return self._hash(set())


#
# Factory functions
#


def int_type(bits: int) -> Int:
    """
    Create the integer type constant for a bit width.

    :param bits:    Width of the integer in bits.
    :return:        An instance of the fixed-width class for 1, 8, 16, 32, 64, 128, 256 or 512 bits, and an IntVar
                    constructed with ``bits`` for every other width.
    """
    fixed_width_cls = {
        1: Int1,
        8: Int8,
        16: Int16,
        32: Int32,
        64: Int64,
        128: Int128,
        256: Int256,
        512: Int512,
    }.get(bits)
    if fixed_width_cls is None:
        # NOTE(review): IntVar receives the bit width as its size, while the fixed-width classes appear to declare
        # SIZE in bytes - confirm the intended unit.
        return IntVar(bits)
    return fixed_width_cls()


def float_type(bits: int) -> Float | None:
    """
    Create the floating-point type constant for a bit width.

    :param bits:    Width of the floating-point number in bits.
    :return:        A Float32 for 32, a Float64 for 64, and None for every other width.
    """
    for width, float_cls in ((32, Float32), (64, Float64)):
        if bits == width:
            return float_cls()
    return None
